package hex.gam;


import hex.genmodel.algos.gam.MSplines;
import hex.glm.GLMModel;
import jsr166y.ThreadLocalRandom;
import org.junit.Test;
import org.junit.runner.RunWith;
import water.DKV;
import water.Scope;
import water.TestUtil;
import water.fvec.Frame;
import water.runner.CloudSize;
import water.runner.H2ORunner;
import water.util.ArrayUtils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;

import static hex.gam.GAMModel.adaptValidFrame;
import static hex.gam.GamBasicISplineTest.assert2DArrayEqual;
import static hex.gam.GamISplineTest.*;
import static hex.gam.GamTestPiping.genFrameKnots;
import static hex.genmodel.algos.gam.GamUtilsISplines.extractKnots;
import static hex.genmodel.algos.gam.GamUtilsISplines.fillKnots;
import static hex.glm.GLMModel.GLMParameters.Family.binomial;
import static hex.glm.GLMModel.GLMParameters.Family.gaussian;
import static org.junit.Assert.assertArrayEquals;

@RunWith(H2ORunner.class)
@CloudSize(1)
public class GamMSplineTest  extends TestUtil {
    // Tolerance handed to assertArrayEquals and assertIdenticalUpToRelTolerance by the tests in this class.
    public static final double EPS = 1e-6;

    /**
     * Checks gamification when thin plate and M-spline smoothers are combined.  The gamified columns produced by a
     * single call that lists every gam column are compared with the gamified columns produced when each gam column
     * is submitted on its own.
     */
    @Test
    public void testTPMSTransform() {
        Scope.enter();
        try {
            Frame trainFrame = parseAndTrackTestFile("smalldata/gam_test/synthetic_20Cols_binomial_20KRows.csv");
            // binomial run: put the categorical version of "response" into the last column and drop the old vec
            int lastCol = trainFrame.numCols() - 1;
            trainFrame.replace(lastCol, trainFrame.vec("response").toCategoricalVec()).remove();
            DKV.put(trainFrame);
            String[][] everyGamCol = {{"c_0", "c_1"}, {"c_2"}, {"c_3"}, {"c_4", "c_5", "c_6"}, {"c_7"},
                    {"c_7", "c_8", "c_9"}};
            Frame combined = extractGamifiedColumns(trainFrame, everyGamCol, new int[]{-1, 2, 3, -1, 4, -1},
                    new double[]{0.001, 0.001, 0.001, 0.001, 0.001, 0.001}, new int[]{1, 3, 3, 1, 3, 1},
                    new int[]{11, 5, 6, 12, 6, 13}, null, binomial);
            List<String> predictors = new ArrayList<>(Arrays.asList(trainFrame.names()));
            predictors.remove(predictors.size() - 1); // the last name is the response

            // thin plate on (c_0, c_1)
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_0", "c_1"}}, null, new double[]{0.001}, new int[]{1}, new int[]{11},
                    ignoredCols(predictors, Arrays.asList("c_0", "c_1")), binomial));
            // M-spline on c_2
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_2"}}, new int[]{2}, new double[]{0.001}, new int[]{3}, new int[]{5},
                    ignoredCols(predictors, Arrays.asList("c_2")), binomial));
            // M-spline on c_3
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_3"}}, new int[]{3}, new double[]{0.001}, new int[]{3}, new int[]{6},
                    ignoredCols(predictors, Arrays.asList("c_3")), binomial));
            // thin plate on (c_4, c_5, c_6)
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_4", "c_5", "c_6"}}, null, new double[]{0.001}, new int[]{1}, new int[]{12},
                    ignoredCols(predictors, Arrays.asList("c_4", "c_5", "c_6")), binomial));
            // M-spline on c_7
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_7"}}, new int[]{4}, new double[]{0.001}, new int[]{3}, new int[]{6},
                    ignoredCols(predictors, Arrays.asList("c_7")), binomial));
            // thin plate on (c_7, c_8, c_9)
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_7", "c_8", "c_9"}}, null, new double[]{0.001}, new int[]{1}, new int[]{13},
                    ignoredCols(predictors, Arrays.asList("c_7", "c_8", "c_9")), binomial));
        } finally {
            Scope.exit();
        }
    }
    
    /***
     * Checks gamification when CS splines (bs = 0) and M-splines (bs = 3) are requested together: the gamified
     * columns of the combined run must match those obtained when each gam column is run on its own.
     */
    @Test
    public void testCSMSTransform() {
        Scope.enter();
        try {
            Frame trainFrame = parseAndTrackTestFile("smalldata/gam_test/synthetic_20Cols_gaussian_20KRows.csv");
            String[][] everyGamCol = {{"c_0"}, {"c_1"}, {"c_2"}, {"c_2"}, {"c_3"}, {"c_5"}, {"c_6"}, {"c_7"},
                    {"c_8"}, {"c_9"}};
            Frame combined = extractGamifiedColumns(trainFrame, everyGamCol,
                    new int[]{-1, 2, 3, -1, -1, 4, 5, 6, -1, -1},
                    new double[]{0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001},
                    new int[]{0, 3, 3, 0, 0, 3, 3, 3, 0, 0}, new int[]{5, 6, 7, 8, 9, 10, 9, 8, 7, 6},
                    null, gaussian);
            List<String> predictors = new ArrayList<>(Arrays.asList(trainFrame.names()));
            predictors.remove(predictors.size() - 1); // the last name is the response

            // CS spline on c_0
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_0"}}, null, new double[]{0.001}, new int[]{0}, new int[]{5},
                    ignoredCols(predictors, Arrays.asList("c_0")), gaussian));
            // M-spline on c_1
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_1"}}, new int[]{2}, new double[]{0.001}, new int[]{3}, new int[]{6},
                    ignoredCols(predictors, Arrays.asList("c_1")), gaussian));
            // M-spline on c_2
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_2"}}, new int[]{3}, new double[]{0.001}, new int[]{3}, new int[]{7},
                    ignoredCols(predictors, Arrays.asList("c_2")), gaussian));
            // CS spline on c_2
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_2"}}, new int[]{-1}, new double[]{0.001}, new int[]{0}, new int[]{8},
                    ignoredCols(predictors, Arrays.asList("c_2")), gaussian));
            // CS spline on c_3
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_3"}}, new int[]{-1}, new double[]{0.001}, new int[]{0}, new int[]{9},
                    ignoredCols(predictors, Arrays.asList("c_3")), gaussian));
            // M-spline on c_5
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_5"}}, new int[]{4}, new double[]{0.001}, new int[]{3}, new int[]{10},
                    ignoredCols(predictors, Arrays.asList("c_5")), gaussian));
            // M-spline on c_6
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_6"}}, new int[]{5}, new double[]{0.001}, new int[]{3}, new int[]{9},
                    ignoredCols(predictors, Arrays.asList("c_6")), gaussian));
            // M-spline on c_7
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_7"}}, new int[]{6}, new double[]{0.001}, new int[]{3}, new int[]{8},
                    ignoredCols(predictors, Arrays.asList("c_7")), gaussian));
            // CS spline on c_8
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_8"}}, new int[]{-1}, new double[]{0.001}, new int[]{0}, new int[]{7},
                    ignoredCols(predictors, Arrays.asList("c_8")), gaussian));
            // CS spline on c_9
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_9"}}, new int[]{-1}, new double[]{0.001}, new int[]{0}, new int[]{6},
                    ignoredCols(predictors, Arrays.asList("c_9")), gaussian));
        } finally {
            Scope.exit();
        }
    }

    /***
     * Checks gamification when thin plate splines (bs = 1), CS splines (bs = 0) and M-splines (bs = 3) are
     * requested together: the gamified columns of the combined run must match those obtained when each gam column
     * is run on its own.
     */
    @Test
    public void testMSCSTPTransform() {
        Scope.enter();
        try {
            Frame trainFrame = parseAndTrackTestFile("smalldata/gam_test/synthetic_20Cols_gaussian_20KRows.csv");
            String[][] everyGamCol = {{"c_0"}, {"c_1", "c_2"}, {"c_2"}, {"c_3"}, {"c_5"}, {"c_6", "c_7", "c_8"},
                    {"c_9"}, {"c_9"}, {"c_9"}};
            Frame combined = extractGamifiedColumns(trainFrame, everyGamCol,
                    new int[]{-1, -1, 9, -1, 10, -1, -1, -1, 8},
                    new double[]{0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001},
                    new int[]{0, 1, 3, 0, 3, 1, 0, 1, 3}, new int[]{5, 11, 7, 8, 9, 12, 8, 11, 6},
                    null, gaussian);
            List<String> predictors = new ArrayList<>(Arrays.asList(trainFrame.names()));
            predictors.remove(predictors.size() - 1); // the last name is the response

            // CS spline on c_0
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_0"}}, null, new double[]{0.001}, new int[]{0}, new int[]{5},
                    ignoredCols(predictors, Arrays.asList("c_0")), gaussian));
            // thin plate on (c_1, c_2)
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_1", "c_2"}}, null, new double[]{0.001}, new int[]{1}, new int[]{11},
                    ignoredCols(predictors, Arrays.asList("c_1", "c_2")), gaussian));
            // M-spline on c_2
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_2"}}, new int[]{9}, new double[]{0.001}, new int[]{3}, new int[]{7},
                    ignoredCols(predictors, Arrays.asList("c_2")), gaussian));
            // CS spline on c_3
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_3"}}, null, new double[]{0.001}, new int[]{0}, new int[]{8},
                    ignoredCols(predictors, Arrays.asList("c_3")), gaussian));
            // M-spline on c_5
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_5"}}, new int[]{10}, new double[]{0.001}, new int[]{3}, new int[]{9},
                    ignoredCols(predictors, Arrays.asList("c_5")), gaussian));
            // thin plate on (c_6, c_7, c_8)
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_6", "c_7", "c_8"}}, null, new double[]{0.001}, new int[]{1}, new int[]{12},
                    ignoredCols(predictors, Arrays.asList("c_6", "c_7", "c_8")), gaussian));
            // CS spline on c_9
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_9"}}, null, new double[]{0.001}, new int[]{0}, new int[]{8},
                    ignoredCols(predictors, Arrays.asList("c_9")), gaussian));
            // thin plate on c_9
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_9"}}, null, new double[]{0.001}, new int[]{1}, new int[]{11},
                    ignoredCols(predictors, Arrays.asList("c_9")), gaussian));
            // M-spline on c_9
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_9"}}, new int[]{8}, new double[]{0.001}, new int[]{3}, new int[]{6},
                    ignoredCols(predictors, Arrays.asList("c_9")), gaussian));
        } finally {
            Scope.exit();
        }
    }

    /***
     * Checks gamification when thin plate splines (bs = 1), M-splines (bs = 3), CS splines (bs = 0) and I-splines
     * (bs = 2) are requested together: the gamified columns of the combined run must match those obtained when
     * each gam column is run on its own.
     */
    @Test
    public void testMSISCSTPTransform() {
        Scope.enter();
        try {
            Frame trainFrame = parseAndTrackTestFile("smalldata/gam_test/synthetic_20Cols_gaussian_20KRows.csv");
            String[][] everyGamCol = {{"c_0"}, {"c_1", "c_2"}, {"c_2"}, {"c_3"}, {"c_5"}, {"c_6", "c_7", "c_8"},
                    {"c_9"}, {"c_9"}, {"c_9"}};
            Frame combined = extractGamifiedColumns(trainFrame, everyGamCol,
                    new int[]{-1, -1, 9, 4, 10, -1, -1, -1, 8},
                    new double[]{0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001, 0.001},
                    new int[]{0, 1, 3, 2, 3, 1, 0, 1, 3}, new int[]{5, 11, 7, 8, 9, 12, 8, 11, 6},
                    null, gaussian);
            List<String> predictors = new ArrayList<>(Arrays.asList(trainFrame.names()));
            predictors.remove(predictors.size() - 1); // the last name is the response

            // CS spline on c_0
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_0"}}, null, new double[]{0.001}, new int[]{0}, new int[]{5},
                    ignoredCols(predictors, Arrays.asList("c_0")), gaussian));
            // thin plate on (c_1, c_2)
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_1", "c_2"}}, null, new double[]{0.001}, new int[]{1}, new int[]{11},
                    ignoredCols(predictors, Arrays.asList("c_1", "c_2")), gaussian));
            // M-spline on c_2
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_2"}}, new int[]{9}, new double[]{0.001}, new int[]{3}, new int[]{7},
                    ignoredCols(predictors, Arrays.asList("c_2")), gaussian));
            // I-spline on c_3
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_3"}}, new int[]{4}, new double[]{0.001}, new int[]{2}, new int[]{8},
                    ignoredCols(predictors, Arrays.asList("c_3")), gaussian));
            // M-spline on c_5
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_5"}}, new int[]{10}, new double[]{0.001}, new int[]{3}, new int[]{9},
                    ignoredCols(predictors, Arrays.asList("c_5")), gaussian));
            // thin plate on (c_6, c_7, c_8)
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_6", "c_7", "c_8"}}, null, new double[]{0.001}, new int[]{1}, new int[]{12},
                    ignoredCols(predictors, Arrays.asList("c_6", "c_7", "c_8")), gaussian));
            // CS spline on c_9
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_9"}}, null, new double[]{0.001}, new int[]{0}, new int[]{8},
                    ignoredCols(predictors, Arrays.asList("c_9")), gaussian));
            // thin plate on c_9
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_9"}}, null, new double[]{0.001}, new int[]{1}, new int[]{11},
                    ignoredCols(predictors, Arrays.asList("c_9")), gaussian));
            // M-spline on c_9
            assertCorrectGamification(combined, extractGamifiedColumns(trainFrame,
                    new String[][]{{"c_9"}}, new int[]{8}, new double[]{0.001}, new int[]{3}, new int[]{6},
                    ignoredCols(predictors, Arrays.asList("c_9")), gaussian));
        } finally {
            Scope.exit();
        }
    }
    
    
    /**
     * Verifies that a validation frame is gamified the same way as the training frame.  The training frame is
     * gamified while the model is built, whereas the validation frame goes through adaptValidFrame; feeding both
     * paths the same data must give matching gamified frames.
     */
    @Test
    public void testGamificationValid() {
        Scope.enter();
        try {
            // both frames come from identical generator arguments, so the test frame mirrors the training frame
            Frame trainFrame = Scope.track(generateRealWithRangeOnly(4, 100, 0, 12345, 4));
            Frame testFrame = Scope.track(generateRealWithRangeOnly(4, 100, 0, 12345, 4));

            // four knots per gam column: min, quarter point, midpoint and max of the training column
            final int gamColCount = 3;
            final int knotCount = 4;
            double[][][] knotValues = new double[gamColCount][knotCount][1];
            for (int col = 0; col < gamColCount; col++) {
                double lo = trainFrame.vec(col).min();
                double hi = trainFrame.vec(col).max();
                knotValues[col][0][0] = lo;
                knotValues[col][1][0] = lo*0.75+hi*0.25;
                knotValues[col][2][0] = (lo+hi)*0.5;
                knotValues[col][knotCount - 1][0] = hi;
            }

            // one knots frame per gam column, published to the DKV and tracked for cleanup
            String[] knotKeys = new String[gamColCount];
            for (int col = 0; col < gamColCount; col++) {
                Frame knotsFrame = genFrameKnots(knotValues[col]);
                DKV.put(knotsFrame);
                Scope.track(knotsFrame);
                knotKeys[col] = knotsFrame._key.toString();
            }

            // model set-up
            GAMModel.GAMParameters params = new GAMModel.GAMParameters();
            params._train = trainFrame._key;
            params._response_column = "C4";
            params._family = gaussian;
            params._solver = GLMModel.GLMParameters.Solver.IRLSM;
            params._max_iterations = 1;
            params._gam_columns = new String[][]{{"C1"}, {"C2"}, {"C3"}};
            params._knot_ids = knotKeys;
            params._bs = new int[]{3, 3, 3};
            params._spline_orders = new int[]{2, 2, 2};
            params._scale = new double[]{0.1, 0.1, 0.1};
            params._savePenaltyMat = true;
            params._keep_gam_cols = true;
            final GAMModel gam = new GAM(params).trainModel().get();
            Scope.track_generic(gam);

            // gamify the test frame through the validation path and compare with the training-side result
            Frame gamifiedValid = adaptValidFrame(testFrame, testFrame, params, gam._output._gamColNames, null,
                    gam._output._zTranspose, gam._output._knots, null, null, null, null, new int[]{0, 0, 3});
            DKV.put(gamifiedValid);
            Scope.track(gamifiedValid);
            Frame gamifiedTrain = DKV.getGet(gam._output._gamTransformedTrainCenter);
            Scope.track(gamifiedTrain);
            TestUtil.assertIdenticalUpToRelTolerance(gamifiedValid, gamifiedTrain, EPS);
        } finally {
            Scope.exit();
        }
    }
    /**
     * Checks the penalty matrices that GAM produces for M-splines of order 2, 3 and 4 against matrices that are
     * derived by hand in this test class.
     */
    @Test
    public void testPenaltyMatrix() {
        Scope.enter();
        try {
            Frame trainFrame = Scope.track(generateRealWithRangeOnly(4, 100, 0, 12345, 1));
            // a single knots frame (one knot per row) is shared by all three gam columns
            final double[] knots = {-1, -0.5, 0, 0.5, 1};
            double[][] knotRows = new double[knots.length][1];
            for (int row = 0; row < knots.length; row++) {
                knotRows[row][0] = knots[row];
            }
            Frame knotsFrame = genFrameKnots(knotRows);
            DKV.put(knotsFrame);
            Scope.track(knotsFrame);

            GAMModel.GAMParameters params = new GAMModel.GAMParameters();
            params._train = trainFrame._key;
            params._response_column = "C4";
            params._family = gaussian;
            params._solver = GLMModel.GLMParameters.Solver.IRLSM;
            params._max_iterations = 1;
            params._gam_columns = new String[][]{{"C1"}, {"C2"}, {"C3"}};
            params._knot_ids = new String[]{knotsFrame._key.toString(), knotsFrame._key.toString(),
                    knotsFrame._key.toString()};
            params._bs = new int[]{3, 3, 3};
            params._spline_orders = new int[]{2, 3, 4};
            params._scale = new double[]{0.1, 0.1, 0.1};
            params._savePenaltyMat = true;
            params._keep_gam_cols = true;
            final GAMModel gam = new GAM(params).trainModel().get();
            Scope.track_generic(gam);

            // compare the model's penalty matrices with the hand-built ones; each check gets its own knot array
            checkOrder2PenaltyMatrix(gam);
            checkOrder3Or4PenaltyMatrix(gam, 3, knots.clone());
            checkOrder3Or4PenaltyMatrix(gam, 4, knots.clone());
        } finally {
            Scope.exit();
        }
    }

    /**
     * Tests the M-spline implementation for order = 1, 2, 3, 4.  The basis values produced by
     * hex.genmodel.algos.gam.MSplines are compared with the ones produced by MSplineManual, which evaluates the
     * M-spline formulae that were derived by hand.
     *
     * For every order three sets of inputs are checked: 50 random values in [0, 1), one fixed value, and the knots
     * themselves.  Previously the random sample was only applied to order 1 because the test array was overwritten
     * with the single fixed value before orders 2, 3 and 4 were checked.
     */
    @Test
    public void testMSpline() {
        double[] knots = new double[]{0, 0.3, 0.5, 0.6, 1};
        // the single point that orders 2-4 used to be checked at; kept as a deterministic input
        double[] pinnedValues = new double[]{0.8542017241262364};
        for (int order = 1; order <= 4; order++) {
            double[] randomValues = DoubleStream
                    .generate(ThreadLocalRandom.current()::nextDouble)
                    .limit(50)
                    .toArray();
            MSplines mspline = new MSplines(order, knots);
            MSplineManual manualMspline = new MSplineManual(order, knots);
            assertCorrectMSpline(mspline, manualMspline, randomValues);
            assertCorrectMSpline(mspline, manualMspline, pinnedValues);
            assertCorrectMSpline(mspline, manualMspline, knots);
        }
    }

    /**
     * Asserts that {@code mspline} and {@code manualMspline} produce the same basis values, within EPS, for every
     * entry of {@code data}.
     *
     * @param mspline       implementation under test; its gamifyVal output is the "actual" side
     * @param manualMspline hand-derived reference; its evaluate output is the "expected" side
     * @param data          input values to evaluate the splines at
     */
    public void assertCorrectMSpline(MSplines mspline, MSplineManual manualMspline, double[] data) {
        double[] mGamifiedValues = new double[mspline._numMBasis];  // buffer reused for every input value
        for (double val : data) {
            double[] manualMGamifiedValues = manualMspline.evaluate(val);
            mspline.gamifyVal(mGamifiedValues, val);
            // testMSpline feeds random inputs, so name the offending value to make a failure reproducible
            assertArrayEquals("M-spline mismatch for order " + manualMspline._order + " at value " + val,
                    manualMGamifiedValues, mGamifiedValues, EPS);
        }
    }

    /**
     * Compares our M-spline basis values with answer files generated by R, for orders 1, 2, 3, 4 and 8.
     */
    @Test
    public void testMSplineAgainstR() {
        final String xFile = "smalldata/gam_test/zero2One.csv";
        final double[] knots = {0, 0.3, 0.5, 0.6, 1.0};
        final int[] orders = {1, 2, 3, 4, 8};
        final String[] answerFiles = {
                "smalldata/gam_test/msGamifiedOrder1.csv",
                "smalldata/gam_test/msGamifiedOrder2.csv",
                "smalldata/gam_test/msGamifiedOrder3.csv",
                "smalldata/gam_test/msGamifiedOrder4.csv",
                "smalldata/gam_test/msGamifiedOrder8.csv"
        };
        for (int i = 0; i < orders.length; i++) {
            assertCorrectMSplineAgainstR(orders[i], knots, xFile, answerFiles[i]);
        }
    }

    /**
     * Gamifies every value in the first column of {@code inputFileName} with an M-spline of the given order and
     * knots, and asserts that each result matches, within EPS, the same row of the answer file.
     *
     * @param order         M-spline order
     * @param knots         knots handed to the MSplines constructor
     * @param inputFileName test file whose first column holds the values to gamify
     * @param answerFrame   test file holding the expected basis values, one column per basis function
     */
    public void assertCorrectMSplineAgainstR(int order, double[] knots, String inputFileName, String answerFrame) {
        Scope.enter();
        try {
            Frame xValues = parseAndTrackTestFile(inputFileName);
            Frame gamifiedAns = parseAndTrackTestFile(answerFrame);
            MSplines mSpline = new MSplines(order, knots);
            int numBasis = mSpline._numMBasis;
            int ncol = gamifiedAns.numCols();
            // a mismatched answer file would otherwise overflow ansVal or be compared against padding zeros
            org.junit.Assert.assertEquals(answerFrame + " must have one column per M-spline basis function",
                    numBasis, ncol);
            double[] mGamifiedValues = new double[numBasis];
            double[] ansVal = new double[numBasis];
            int nrow = (int) xValues.numRows();
            for (int dataIndex = 0; dataIndex < nrow; dataIndex++) {
                double val = xValues.vec(0).at(dataIndex);
                mSpline.gamifyVal(mGamifiedValues, val);
                readDataRow(gamifiedAns, dataIndex, ncol, ansVal);
                assertArrayEquals("order " + order + " mismatch at row " + dataIndex + " (x = " + val + ")",
                        ansVal, mGamifiedValues, EPS);
            }
        } finally {
            Scope.exit();
        }
    }
    
    /**
     * Copies the values of columns 0 .. numCols-1 at row {@code rowIndex} of {@code data} into {@code ansRow}.
     */
    public static void readDataRow(Frame data, int rowIndex, int numCols, double[] ansRow) {
        int col = 0;
        while (col < numCols) {
            ansRow[col] = data.vec(col).at(rowIndex);
            col++;
        }
    }
        
    /**
     * M-spline evaluated with manually derived, closed-form formulae for order 1 to 4.  Used as the reference that
     * the MSplines implementation is compared against in testMSpline.
     */
    public class MSplineManual {
        public int _order;                  // spline order given to the constructor
        public double[] _knots;             // result of fillKnots(knots, order)
        public int _numBasis;               // number of basis functions: knots.length + order - 2
        public int _totKnots;               // length of _knots
        public MSplineBasisM[] _mSPlines;   // one entry per basis function
        /**
         * @param order spline order
         * @param knots original knots (before fillKnots is applied)
         */
        public MSplineManual(int order, double[] knots) {
            _order = order;
            // NOTE(review): fillKnots presumably pads the knot sequence with duplicated boundary knots - confirm
            // in GamUtilsISplines
            _knots = fillKnots(knots, order);
            _totKnots = _knots.length;
            _numBasis = knots.length + order - 2;
            _mSPlines = new MSplineBasisM[_numBasis];
            for (int index = 0; index < _numBasis; index++)
                _mSPlines[index] = new MSplineBasisM(order, _knots, index);
        }

        /**
         * Evaluates every basis function at val.
         *
         * @return new array of length _numBasis, entry i being gamify(val, i)
         */
        public double[] evaluate(double val) {
            double[] gamifiedVal = new double[_numBasis];
            for (int index = 0; index < _numBasis; index++) {
                gamifiedVal[index] = gamify(val, index);
            }
            return gamifiedVal;
        }

        /**
         * Evaluates basis function basisInd at val using the local knots stored for that basis function.
         * Returns 0 when val is outside the half-open interval [knots[0], knots[_order]).  For order 1 to 4 the
         * polynomial piece of the knot interval containing val is evaluated; a term whose denominator is 0 is
         * skipped.  Any other order, and any case not matched below, returns 0.
         */
        public double gamify(double val, int basisInd) {
            double[] knots = _mSPlines[basisInd]._knots;
            if (val < knots[0])
                return 0;
            if (val >= knots[_order])   // right end is exclusive
                return 0;
            if (_order == 1) { // order of the M-spline
                if (val >= knots[0] && val < knots[1] && knots[1] != knots[0])
                    return 1.0 / (knots[1] - knots[0]);
            } else if (_order == 2) {
                if (val >= knots[0] && val < knots[1])  {
                    double temp = (knots[2]-knots[0])*(knots[1]-knots[0]);
                    if (temp != 0)
                        return 2 * (val -knots[0]) / temp;
                } 
                if (val >= knots[1] && val < knots[2]) {
                    double temp = (knots[2]-knots[1])*(knots[2]-knots[0]);
                    if (temp != 0)
                    return 2*(knots[2]-val)/temp;
                }
            } else if (_order == 3) {
                if (val >= knots[0] && val < knots[1]) {
                    double temp = (knots[3]-knots[0])*(knots[2]-knots[0])*(knots[1]-knots[0]);
                    if (temp != 0)
                        return 3*(val-knots[0])*(val-knots[0])/temp;
                }
                if (val >= knots[1] && val < knots[2]) {
                    // middle piece is the sum of two terms, each guarded against a zero denominator
                    double temp = (knots[2]-knots[1])*(knots[2]-knots[0])*(knots[3]-knots[0]);
                    double temp1 = (knots[3]-knots[1])*(knots[2]-knots[1])*(knots[3]-knots[0]);
                    double sumVal = 0;
                    if (temp != 0)
                        sumVal += 3*(val-knots[0])*(knots[2]-val)/temp;
                    if (temp1 != 0)
                        sumVal += 3*(knots[3]-val)*(val-knots[1])/temp1;
                    return sumVal;
                }
                if (val >= knots[2] && val < knots[3]) {
                    double temp = (knots[3]-knots[0])*(knots[3]-knots[2])*(knots[3]-knots[1]);
                    if (temp != 0)
                        return 3*(knots[3]-val)*(knots[3]-val)/temp;
                }
            } else if (_order == 4) {
                if (val >= knots[0] && val < knots[1]) {
                    double temp = (knots[3]-knots[0])*(knots[2]-knots[0])*(knots[1]-knots[0])*(knots[4]-knots[0]);
                    if (temp != 0)
                        return 4*(val-knots[0])*(val-knots[0])*(val-knots[0])/temp;
                }
                if (val >= knots[1] && val < knots[2]) {
                    // second piece: sum of three terms, each guarded against a zero denominator
                    double sumVal = 0.0;
                    double temp = (knots[3]-knots[0])*(knots[2]-knots[0])*(knots[2]-knots[1])*(knots[4]-knots[0]);
                    double temp1 = (knots[3]-knots[0])*(knots[3]-knots[1])*(knots[2]-knots[1])*(knots[4]-knots[0]);
                    double temp2 = (knots[4]-knots[1])*(knots[3]-knots[1])*(knots[2]-knots[1])*(knots[4]-knots[0]);
                    if (temp != 0)
                        sumVal += 4*(val-knots[0])*(val-knots[0])*(knots[2]-val)/temp;
                    if (temp1 != 0)
                        sumVal += 4*(val-knots[0])*(val-knots[1])*(knots[3]-val)/temp1;
                    if (temp2 != 0)
                        sumVal += 4*(knots[4]-val)*(val-knots[1])*(val-knots[1])/temp2;
                    return sumVal;
                }
                if (val >= knots[2] && val < knots[3]) {
                    // third piece: sum of three terms, each guarded against a zero denominator
                    double sumVal = 0.0;
                    double temp = (knots[3]-knots[0])*(knots[3]-knots[1])*(knots[3]-knots[2])*(knots[4]-knots[0]);
                    double temp1 = (knots[4]-knots[1])*(knots[3]-knots[1])*(knots[3]-knots[2])*(knots[4]-knots[0]);
                    double temp2 = (knots[4]-knots[1])*(knots[4]-knots[2])*(knots[3]-knots[2])*(knots[4]-knots[0]);
                    if (temp != 0)
                        sumVal += 4*(val-knots[0])*(knots[3]-val)*(knots[3]-val)/temp;
                    if (temp1 != 0)
                        sumVal += 4*(knots[4]-val)*(val-knots[1])*(knots[3]-val)/temp1;
                    if (temp2 != 0)
                        sumVal += 4*(knots[4]-val)*(knots[4]-val)*(val-knots[2])/temp2;
                    return sumVal;
                }
                if (val >= knots[3] && val < knots[4]) {
                    double temp = (knots[4]-knots[0])*(knots[4]-knots[1])*(knots[4]-knots[2])*(knots[4]-knots[3]);
                    if (temp != 0)
                        return 4*(knots[4]-val)*(knots[4]-val)*(knots[4]-val)/temp;
                }
            }
            return 0;
        }
        
        /**
         * Holder for the order, index and local knots of one basis function.
         */
        public class MSplineBasisM {
            public int _order;
            public double[] _knots;     // local knots; gamify reads entries 0 .. order of this array
            public int _basisInd;

            public MSplineBasisM(int order, double[] knots, int index) {
                _order = order;
                _basisInd = index;
                // NOTE(review): extractKnots presumably returns the order+1 consecutive knots starting at index -
                // confirm in GamUtilsISplines
                _knots = extractKnots(index, order, knots);
            }
        }
    }
    
    /**
     * Compares a penalty matrix stored in the model with one built by hand from {@code knots}.  Order 3 is read
     * from slot 1 of the model output and any other order from slot 2; the hand-built matrix is multiplied by the
     * reciprocal of the matching _penaltyScale entry before the comparison.
     */
    public void checkOrder3Or4PenaltyMatrix(GAMModel gam, int order, double[] knots) {
        final boolean isOrder3 = order == 3;
        final int slot = isOrder3 ? 1 : 2;
        double[][] expected;
        if (isOrder3) {
            expected = manualPenalty3Order(knots, order);
        } else {
            expected = manualPenalty4Order(knots, order);
        }
        double[][] fromModel = gam._output._penaltyMatrices[slot];
        double inverseScale = 1.0 / gam._output._penaltyScale[slot];
        ArrayUtils.mult(expected, inverseScale);
        assert2DArrayEqual(expected, fromModel);
    }
    
    /**
     * Checks that the first penalty matrix of the model is the 5 x 5 zero matrix.
     */
    public void checkOrder2PenaltyMatrix(GAMModel gam) {
        // the second derivative of a polynomial ax+b is zero everywhere, so every entry is expected to be 0
        double[][] allZeros = new double[5][5];
        assert2DArrayEqual(allZeros, gam._output._penaltyMatrices[0]);
    }
    
    /**
     * Builds the symmetric penalty matrix for 3rd order M-splines by hand.  Entry (i, j) is the value returned by
     * {@code integrateMult3Order} for the second derivatives of basis function i and basis function j.
     *
     * @param knots : knots without duplication
     * @param order : order of the M-spline
     * @return square matrix with knots.length+order-2 rows
     */
    double[][] manualPenalty3Order(double[] knots, int order) {
        final int basisCount = knots.length + order - 2;
        final double[] paddedKnots = fillKnots(knots, order);
        // one row of per-interval constants per basis function, note that there are intervals with 0 value
        final double[][] derivs = mSpline3Order2ndDeriv(knots, order);
        final double[][] penalty = new double[basisCount][basisCount];
        for (int row = 0; row < basisCount; row++) {
            final double[] rowDeriv = derivs[row];
            // only the upper triangle is integrated, the lower triangle is mirrored
            for (int col = row; col < basisCount; col++) {
                final double entry = integrateMult3Order(rowDeriv, row, derivs[col], col, paddedKnots);
                penalty[row][col] = entry;
                penalty[col][row] = entry;
            }
        }
        return penalty;
    }
    
    /**
     * Builds the symmetric penalty matrix for 4th order M-splines by hand.  Entry (i, j) is the value returned by
     * {@code integrateMult4Order} for the second derivatives of basis function i and basis function j.
     *
     * @param knots : knots without duplication
     * @param order : order of the M-spline
     * @return square matrix with knots.length+order-2 rows
     */
    double[][] manualPenalty4Order(double[] knots, int order) {
        final int basisCount = knots.length + order - 2;
        final double[] paddedKnots = fillKnots(knots, order);
        // per basis function: one {constant, slope} pair per interval, note that there are intervals with 0 value
        final double[][][] derivs = splines4thOrder2ndDeriv(knots, order);
        final double[][] penalty = new double[basisCount][basisCount];
        for (int row = 0; row < basisCount; row++) {
            final double[][] rowDeriv = derivs[row];
            // only the upper triangle is integrated, the lower triangle is mirrored
            for (int col = row; col < basisCount; col++) {
                final double entry = integrateMult4Order(rowDeriv, row, derivs[col], col, paddedKnots);
                penalty[row][col] = entry;
                penalty[col][row] = entry;
            }
        }
        return penalty;
    }
    /**
     * Integrates the product of two piecewise linear functions over the knot intervals they share.  Piece p of the
     * first function lives on interval [totKnots[firstIndex+p], totKnots[firstIndex+p+1]], piece q of the second
     * one on the interval starting at totKnots[secondIndex+q].  Each piece is a {constant, slope} pair.
     *
     * @param firstBasis : pieces of the first function, its length also bounds the pieces read from secondBasis
     * @param firstIndex : index into totKnots of the left end of the first piece of firstBasis
     * @param secondBasis : pieces of the second function
     * @param secondIndex : index into totKnots of the left end of the first piece of secondBasis
     * @param totKnots : knot sequence indexed by firstIndex/secondIndex plus piece number
     * @return sum over the shared intervals of the integral of the product
     */
    public static double integrateMult4Order(double[][] firstBasis, int firstIndex, double[][] secondBasis, int secondIndex,
                                             double[] totKnots) {
        final int pieceCount = firstBasis.length;
        final Double[] quad = new Double[3];    // scratch: product as quad[0] + quad[1]*x + quad[2]*x*x
        double integral = 0.0;
        for (int piece = 0; piece < pieceCount; piece++) {
            final int knotPos = firstIndex + piece;
            final int otherPiece = knotPos - secondIndex; // piece of secondBasis on the same interval
            if (otherPiece < 0 || otherPiece >= pieceCount)
                continue;   // second function has no piece on this interval
            multiplyBasis(firstBasis[piece], secondBasis[otherPiece], quad);
            if (quad[0] == 0 && quad[1] == 0 && quad[2] == 0)
                continue;   // product vanishes here, nothing to add
            final double lo = totKnots[knotPos];
            final double hi = totKnots[knotPos + 1];
            // antiderivative of the quadratic evaluated between lo and hi
            integral += quad[2] * (hi * hi * hi - lo * lo * lo) / 3
                    + quad[1] * (hi * hi - lo * lo) / 2 + quad[0] * (hi - lo);
        }
        return integral;
    }

    /**
     * Multiplies two linear polynomials stored with the constant term first:
     * (first[0] + first[1]*x) * (second[0] + second[1]*x).
     *
     * @param first : {constant, slope} of the first polynomial
     * @param second : {constant, slope} of the second polynomial
     * @param product : output, receives {constant, linear, quadratic} coefficients of the product
     */
    public static void multiplyBasis(double[] first, double[] second, Double[] product) {
        final double constA = first[0];
        final double slopeA = first[1];
        final double constB = second[0];
        final double slopeB = second[1];
        product[0] = constA * constB;
        product[2] = slopeA * slopeB;
        product[1] = constA * slopeB + slopeA * constB;
    }
    
    /**
     * Integrates the product of two piecewise constant functions over the knot intervals they share.  Piece p of
     * the first function lives on interval [totKnots[firstIndex+p], totKnots[firstIndex+p+1]], piece q of the
     * second one on the interval starting at totKnots[secondIndex+q].
     *
     * @param firstBasis : per-interval constants of the first function, its length also bounds the pieces read
     *                   from secondBasis
     * @param firstIndex : index into totKnots of the left end of the first piece of firstBasis
     * @param secondBasis : per-interval constants of the second function
     * @param secondIndex : index into totKnots of the left end of the first piece of secondBasis
     * @param totKnots : knot sequence indexed by firstIndex/secondIndex plus piece number
     * @return sum over the shared intervals of (product of the two constants) * (interval width)
     */
    public static double integrateMult3Order(double[] firstBasis, int firstIndex, double[] secondBasis, int secondIndex,
                                             double[] totKnots) {
        final int pieceCount = firstBasis.length;
        double integral = 0.0;
        for (int piece = 0; piece < pieceCount; piece++) {
            final int knotPos = firstIndex + piece;
            final int otherPiece = knotPos - secondIndex; // piece of secondBasis on the same interval
            if (otherPiece < 0 || otherPiece >= pieceCount)
                continue;   // second function has no piece on this interval
            if (firstBasis[piece] == 0 || secondBasis[otherPiece] == 0)
                continue;   // product is zero on this interval
            final double product = firstBasis[piece] * secondBasis[otherPiece];
            integral += product * (totKnots[knotPos + 1] - totKnots[knotPos]);
        }
        return integral;
    }
    
    /**
     * Collects the second derivative description of every 4th order M-spline basis function.  Entry i is the
     * result of {@code m4Order2ndDeriv} applied to the knots extracted for basis function i.
     *
     * @param knots : knots without duplication
     * @param order : order of the M-spline
     * @return array with knots.length+order-2 entries, one per basis function
     */
    double[][][] splines4thOrder2ndDeriv(double[] knots, int order) {
        final double[] paddedKnots = fillKnots(knots, order);
        final double[][][] derivs = new double[knots.length + order - 2][][];
        Arrays.setAll(derivs, basis -> m4Order2ndDeriv(extractKnots(basis, order, paddedKnots), order));
        return derivs;
    }

    /**
     * Collects the second derivative description of every 3rd order M-spline basis function.  Entry i is the
     * result of {@code m3Order2ndDeriv} applied to the knots extracted for basis function i.
     *
     * @param knots : knots without duplication
     * @param order : order of the M-spline
     * @return array with knots.length+order-2 entries, one per basis function
     */
    double[][] mSpline3Order2ndDeriv(double[] knots, int order) {
        final double[] paddedKnots = fillKnots(knots, order);
        final double[][] derivs = new double[knots.length + order - 2][];
        Arrays.setAll(derivs, basis -> m3Order2ndDeriv(extractKnots(basis, order, paddedKnots), order));
        return derivs;
    }
    
    /**
     * Second derivative of one 4th order M-spline basis function, described piece by piece.  Row p of the result
     * holds the coefficients of a+bt on the interval between knots[p] and knots[p+1]: entry 0 is a (degree 0),
     * entry 1 is b (degree 1).  Any term whose denominator is zero (repeated knots) contributes 0.
     *
     * @param knots : knots of this basis function, indices 0 to 4 are read
     * @param order : number of rows allocated, rows 0 to 3 are written
     * @return double[order][2] of {constant, slope} pairs
     */
    public static double[][] m4Order2ndDeriv(double[] knots, int order) {
        final double[][] pieces = new double[order][2];
        final double k0 = knots[0], k1 = knots[1], k2 = knots[2], k3 = knots[3], k4 = knots[4];

        // piece between knots[0] and knots[1]
        final double den0 = (k3 - k0) * (k2 - k0) * (k1 - k0) * (k4 - k0);
        pieces[0][0] = ratioOrZero(-24.0 * k0, den0);
        pieces[0][1] = ratioOrZero(24.0, den0);

        // piece between knots[1] and knots[2]: three terms
        final double den1a = (k3 - k0) * (k2 - k0) * (k2 - k1) * (k4 - k0);
        final double den1b = (k3 - k0) * (k3 - k1) * (k2 - k1) * (k4 - k0);
        final double den1c = (k4 - k1) * (k3 - k1) * (k2 - k1) * (k4 - k0);
        pieces[1][0] = ratioOrZero(8.0 * k2 + 16 * k0, den1a) + ratioOrZero(8 * k3 + 8 * k0 + 8 * k1, den1b)
                + ratioOrZero(8 * k4 + 16 * k1, den1c);
        pieces[1][1] = ratioOrZero(-24.0, den1a) - ratioOrZero(24, den1b) - ratioOrZero(24, den1c);

        // piece between knots[2] and knots[3]: three terms
        final double den2a = (k3 - k0) * (k3 - k1) * (k3 - k2) * (k4 - k0);
        final double den2b = (k4 - k1) * (k3 - k1) * (k3 - k2) * (k4 - k0);
        final double den2c = (k4 - k1) * (k4 - k2) * (k3 - k2) * (k4 - k0);
        pieces[2][0] = ratioOrZero(-16.0 * k3 - 8 * k0, den2a) + ratioOrZero(-8 * k3 - 8.0 * k4 - 8 * k1, den2b)
                + ratioOrZero(-16 * k4 - 8 * k2, den2c);
        pieces[2][1] = ratioOrZero(24, den2a) + ratioOrZero(24, den2b) + ratioOrZero(24.0, den2c);

        // piece between knots[3] and knots[4]
        final double den3 = (k4 - k0) * (k4 - k1) * (k4 - k2) * (k4 - k3);
        pieces[3][0] = ratioOrZero(24.0 * k4, den3);
        pieces[3][1] = ratioOrZero(-24.0, den3);
        return pieces;
    }

    /** Returns numerator/denominator, or 0 when the denominator is zero. */
    private static double ratioOrZero(double numerator, double denominator) {
        return denominator == 0 ? 0 : numerator / denominator;
    }
    
    /**
     * Second derivative of one 3rd order M-spline basis function, stored as one constant per knot interval:
     * entry p is the value on the interval between knots[p] and knots[p+1].  Any term whose denominator is zero
     * (repeated knots) contributes 0.  The two terms of the middle interval are guarded independently so that a
     * single vanishing denominator can never produce an infinite or NaN entry.
     *
     * @param knots : only contains knots over which basis function is active, indices 0 to 3 are read
     * @param order : length of the returned array, entries 0 to 2 are written
     * @return double[order] with the per-interval second derivative values
     */
    public static double[] m3Order2ndDeriv(double[] knots, int order) {
        double[] secondDeriv = new double[order];
        // interval between knots[0] and knots[1]
        double temp = (knots[3]-knots[0])*(knots[2]-knots[0])*(knots[1]-knots[0]);
        if (temp != 0)
            secondDeriv[0] = 6.0/temp;
        // interval between knots[1] and knots[2]: two terms, each skipped on its own when its denominator is zero
        temp = (knots[2]-knots[1])*(knots[2]-knots[0])*(knots[3]-knots[0]);
        double temp2 = (knots[3]-knots[1])*(knots[2]-knots[1])*(knots[3]-knots[0]);
        double middle = 0;
        if (temp != 0)
            middle -= 6.0/temp;
        if (temp2 != 0)
            middle -= 6.0/temp2;
        secondDeriv[1] = middle;
        // interval between knots[2] and knots[3]
        temp = (knots[3]-knots[0])*(knots[3]-knots[2])*(knots[3]-knots[1]);
        if (temp != 0)
            secondDeriv[2] = 6.0/temp;
        return secondDeriv;
    }
}
